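"""Functions which compute prices.

Contract prices are simulated per contract and random seed, and the resulting
price components are used to build regression-based price moments.
"""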
import copy
import numpy as np
import pandas as pd
from numba import njit
import statsmodels.formula.api as smf
from src.models_new import states
try:
    from src.interpolation.splines import eval_linear
except:
    from ..interpolation.splines import eval_linear


@njit
def get_price_once(
        mri, tau, state_initial, values, grid, prob_extend, const,
        r, sigma, m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3, delta, beta, seed):
    """
    Simulate one contract path and return its price and 'fast price components'.

    While the price computation does use delta, the 'fast price components'
    do not use delta and so this must be incorporated later.

    Args:
        mri: scalar entering the match value linearly (via ``m_1``) and
            through the cubic polynomial with coefficients ``rho_0..rho_3``.
        tau: number of simulated contract periods; must be at least 1.
        state_initial: initial state; ``state[0]`` enters the match value.
        values: values interpolated with ``eval_linear`` on ``grid``.
        grid: interpolation grid for ``values``.
        prob_extend: callable ``prob_extend(state, tau, mri)``; its result
            weights the end-of-contract value at ``tau`` against ``tau + 1``
            (presumably the extension probability — confirm).
        const: forwarded to ``states.next_state``.
        r: forwarded to ``states.next_state``.
        sigma: forwarded to ``states.next_state``.
        m_0: constant term of the per-period match value.
        m_1: coefficient on ``mri`` in the match value.
        m_2: coefficient on ``state[0] * rho(mri)`` in the match value.
        rho_0: constant term of the cubic in ``mri``.
        rho_1: linear coefficient of the cubic in ``mri``.
        rho_2: quadratic coefficient of the cubic in ``mri``.
        rho_3: cubic coefficient of the cubic in ``mri``.
        delta: weight on the discounted match value in the price; the
            continuation term gets weight ``1 - delta``.
        beta: per-period discount factor.
        seed: passed to ``np.random.seed`` before the simulation.

    Returns:
        ``(price, match_value_0, match_value_1, match_value_2, match_value_3)``,
        each divided by the discount sum ``sum_{t < tau} beta ** t``:
        ``match_value_0`` is the discounted count of periods,
        ``match_value_1`` the discounted ``mri``,
        ``match_value_2`` the discounted ``state[0]`` and
        ``match_value_3`` the continuation (outside-option) term.

    Raises:
        ValueError: if ``tau < 1``; the outside option is only defined after
            one period and the discount sum would be zero.
    """
    if tau < 1:
        raise ValueError("tau must be at least 1")

    # Cubic in mri; scaled by state[0] inside each period's match value.
    value = rho_0 + rho_1 * mri + rho_2 * mri * mri + rho_3 * mri * mri * mri
    np.random.seed(seed)
    state = state_initial

    beta_sum = 0.0
    match_values = 0.0
    match_value_0 = 0.0
    match_value_1 = 0.0
    match_value_2 = 0.0
    # Assumes eval_linear returns a scalar for a single state; set below.
    outside_option = 0.0

    # Do contract simulation
    for t in range(tau):
        discount = beta ** t
        beta_sum += discount
        match_values += discount * (m_0 + m_1 * mri + m_2 * state[0] * value)
        match_value_0 += discount
        match_value_1 += discount * mri
        match_value_2 += discount * state[0]

        state = states.next_state(state, const, r, sigma)
        # The outside option is valued at the state reached after the first
        # period. Evaluating it here (rather than at the top of t == 1) also
        # covers tau == 1 and uses exactly the same state and RNG draws.
        if t == 0:
            outside_option = eval_linear(grid, values, state)

    # End-of-contract values, at tau and (if extended) at tau + 1.
    p_extend = prob_extend(state, tau, mri)
    value_tau = eval_linear(grid, values, state)
    state = states.next_state(state, const, r, sigma)
    value_tau_plus_one = eval_linear(grid, values, state)

    continuation = (
        beta * outside_option - beta ** tau * (
            (1 - p_extend) * value_tau + p_extend * beta * value_tau_plus_one
        )
    )
    price = (delta * match_values + (1 - delta) * continuation) / beta_sum

    return (
        price,
        match_value_0 / beta_sum,
        match_value_1 / beta_sum,
        match_value_2 / beta_sum,
        continuation / beta_sum,
    )


def get_price_all(mri, tau, state, grid, values, const, r, sigma, prob_extend, params, seeds):
    """Average the simulated price components of one contract over several seeds.

    Runs ``get_price_once`` once per seed, with model coefficients taken from
    ``params``. It then averages each of the five returned components.

    Args:
        mri, tau, state, grid, values, const, r, sigma, prob_extend:
            forwarded to ``get_price_once`` (``state`` as ``state_initial``).
        params: mapping with keys 'm_0', 'm_1', 'm_2', 'rho_0'..'rho_3',
            'delta' and 'beta'.
        seeds: iterable of seeds; results are keyed by seed, so a repeated
            seed contributes only once to the averages.

    Returns:
        Tuple ``(price, match_value_0, match_value_1, match_value_2,
        match_value_3)`` of means over the distinct seeds.
    """
    draws = {
        seed: get_price_once(
            mri=mri,
            tau=tau,
            state_initial=state,
            grid=grid,
            values=values,
            prob_extend=prob_extend,
            const=const,
            r=r,
            sigma=sigma,
            m_0=params['m_0'],
            m_1=params['m_1'],
            m_2=params['m_2'],
            rho_0=params['rho_0'],
            rho_1=params['rho_1'],
            rho_2=params['rho_2'],
            rho_3=params['rho_3'],
            delta=params['delta'],
            beta=params['beta'],
            seed=seed,
        )
        for seed in seeds
    }
    per_seed = list(draws.values())

    # Component order: price, match_value_0, ..., match_value_3.
    return tuple(
        np.mean([draw[component] for draw in per_seed])
        for component in range(5)
    )


def build_fast_prices(df_contracts, search_grid_by_spec, search_value_by_spec, const,
                      r, sigma, prob_extend_by_spec, params, seeds):
    """Compute seed-averaged price components for every contract, grouped by spec.

    For each spec in 'low', 'mid', 'high', every row of ``df_contracts`` with
    that spec is priced with ``get_price_all``. The initial state is
    ``[g, n_l, n_m, n_h]`` from the row, and the spec's grid, values and
    ``prob_extend`` are used.

    Args:
        df_contracts: contracts with columns 'spec', 'g', 'n_l', 'n_m', 'n_h',
            'mri' and 'tau'.
        search_grid_by_spec: per-spec interpolation grid.
        search_value_by_spec: per-spec values on that grid.
        const: pandas object; ``const.values.T[0]`` is passed on.
        r: pandas object; ``r.values`` is passed on.
        sigma: pandas object; ``sigma.values[0][0]`` is passed on.
        prob_extend_by_spec: per-spec ``prob_extend`` callable.
        params: model parameters forwarded to ``get_price_all``.
        seeds: seeds forwarded to ``get_price_all``.

    Returns:
        dict mapping spec to ``{0, 1, 2, 3: component arrays, 'price': prices}``.
        Each array is ordered like that spec's rows in ``df_contracts``.
    """
    by_spec = {}
    for spec in ['low', 'mid', 'high']:
        per_contract = []
        subset = df_contracts[df_contracts['spec'] == spec]
        for contract in subset.itertuples():
            initial_state = np.array([contract.g, contract.n_l, contract.n_m, contract.n_h])
            per_contract.append(get_price_all(
                mri=contract.mri,
                tau=contract.tau,
                state=initial_state,
                grid=search_grid_by_spec[spec],
                values=search_value_by_spec[spec],
                const=const.values.T[0],
                r=r.values,
                sigma=sigma.values[0][0],
                prob_extend=prob_extend_by_spec[spec],
                params=params,
                seeds=seeds,
            ))

        # get_price_all returns (price, mv_0, mv_1, mv_2, mv_3); component k
        # lives at tuple position k + 1.
        stacked = {
            k: np.array([result[k + 1] for result in per_contract])
            for k in range(4)
        }
        stacked['price'] = np.array([result[0] for result in per_contract])
        by_spec[spec] = stacked
    return by_spec

# NOTE(review): The module-level string literal below is disabled code, not a
# live definition. `get_fast_price_deviations` is never defined when this
# module is imported. It reads parameters as `params[...].value`, unlike
# `get_fast_price_moments`, which indexes `params` directly. Consider deleting
# it, or restoring it as real code if it is still needed.
'''
def get_fast_price_deviations(params, match_values_by_spec, prices_by_spec,
                              waterd_by_spec, values_by_spec, verbose=False):
    prices_predicted_by_spec = dict()
    deviations_by_spec = dict()
    for spec in ['low', 'mid', 'high']:
        prices_predicted_by_spec[spec] = (
            match_values_by_spec[spec][0] * params[f'm_0_{spec}'].value
            + match_values_by_spec[spec][1] * params[f'm_1_{spec}'].value
            + match_values_by_spec[spec][2] * params[f'm_2'].value * values_by_spec[spec]
            + match_values_by_spec[spec][3]
            + params['waterd_control'].value * waterd_by_spec[spec]
        )
        deviations_by_spec[spec] = prices_predicted_by_spec[spec] - prices_by_spec[spec]
    deviations = np.concatenate(list(deviations_by_spec.values()))
    if verbose:
        return deviations, prices_predicted_by_spec, deviations_by_spec
    if not verbose:
        return deviations
'''


def get_fast_price_moments(params, price_match_values_by_spec, values_by_spec, df_contracts,
                           non_myopic_dict=None):
    """Predict prices from the fast price components and compute price moments.

    For each spec, the delta-weighted match-value part of the price is written
    to a ``diff_price_predicted`` column of a copy of ``df_contracts``. That
    column is then regressed (OLS) on spec dummies, ``mri`` x spec interactions
    and ``g:value``.

    Args:
        params: mapping with 'delta', 'm_2', per-spec 'm_0_{spec}' and
            'm_1_{spec}', and 'beta' (used only when ``non_myopic_dict`` is given).
        price_match_values_by_spec: per spec, a mapping with keys 0..3 of
            component arrays (the layout ``build_fast_prices`` produces). They
            must be ordered like that spec's rows in ``df_contracts``.
        values_by_spec: per spec, a multiplier applied to component 2 together
            with 'm_2'.
        df_contracts: contracts with at least 'spec', 'mri', 'g' and 'value'
            columns. It is copied, not modified.
        non_myopic_dict: optional mapping with 'prob_exit' and per-spec
            'prob_match_contracts'. When given, the match components are scaled
            by ``1 - beta * (1 - prob_exit) * prob_match_contracts[spec]``.
            When None, no scaling is applied.

    Returns:
        ``(moments, prices_new_by_spec)``. ``moments`` is the OLS coefficients
        concatenated with the mean predicted price per spec (index
        'price_{spec}'). ``prices_new_by_spec`` maps 'price_{spec}' to the
        predicted prices.
    """
    delta = params['delta']
    prices_new_by_spec = dict()
    av_prices_by_spec = dict()

    # Deep copy so the new column never leaks into the caller's frame.
    df_contracts1 = df_contracts.copy()
    df_contracts1['diff_price_predicted'] = 0.0

    for spec in ['low', 'mid', 'high']:
        if non_myopic_dict is None:
            adjustment = 1.0
        else:
            adjustment = 1 - params['beta'] * (1 - non_myopic_dict['prob_exit']) * \
                         non_myopic_dict['prob_match_contracts'][spec]

        components = price_match_values_by_spec[spec]
        # Match-value part of the price, computed once so the regression
        # target and the predicted price cannot drift apart.
        match_part = (
            components[0] * adjustment * params[f'm_0_{spec}']
            + components[1] * adjustment * params[f'm_1_{spec}']
            + components[2] * adjustment * params['m_2'] * values_by_spec[spec]
        )

        # Component 3 (the outside-option / continuation term) is left out of
        # the regression target: equivalent to removing the outside option on
        # both sides of the price equation.
        df_contracts1.loc[
            df_contracts1['spec'] == spec, 'diff_price_predicted'
        ] = delta * match_part

        # Full predicted prices (some moments depend on the actual price
        # level rather than the difference).
        prices_new_by_spec[f'price_{spec}'] = (
            delta * match_part + (1 - delta) * components[3]
        )
        av_prices_by_spec[f'price_{spec}'] = prices_new_by_spec[f'price_{spec}'].mean()

    reg = smf.ols(
        formula='diff_price_predicted ~ C(spec, Treatment(reference="low")) '
                '+ mri : C(spec, Treatment(reference="low")) + g : value',
        data=df_contracts1
    ).fit()

    return pd.concat([reg.params, pd.Series(av_prices_by_spec)]), prices_new_by_spec
